#!/usr/bin/env python3
import argparse
import json
import os
import gc
import numpy as np
import pandas as pd
import torch
import time
import math

from transformers import AutoModelForCausalLM, AutoTokenizer
# We assume that the environment has the vllm package installed.
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import destroy_model_parallel
from models.system_prompts import LLAMA2_CLS_PROMPT
from models.language_models import Llama2_7b, Vicuna13b, Mistral7b, PhiMini, Llama3_8b, Qwen7b, Gemma7b, Zephyr7bR2D2, Mistral7B_RR, Llama3_8bRR


def harmbench_judge_fn(classifier, prompts: list, responses: list) -> list:
    """Uses the HarmBench classifier model (via vllm) to evaluate candidate responses."""

    classifier_params = SamplingParams(temperature=0.0, max_tokens=1)
    inputs = [
        LLAMA2_CLS_PROMPT['prompt'].format(behavior=prompt, generation=response)
        for prompt, response in zip(prompts, responses)
    ]
 
    outputs = classifier.generate(inputs, classifier_params, use_tqdm=False)
    preds = [o.outputs[0].text.strip() for o in outputs]

    labels = ["yes", "no"]
    for pred in preds:
        if pred.lower() not in labels:
            print("WARNING: Detected invalid label in HarmBench evaluation. Defaulting to 'no' for invalid labels.")
    
    classifications = [int(pred.lower() == "yes") for pred in preds]
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    return classifications

# ---------------------------------------------------------------------
# Dictionary mapping model keys to model classes.
# ---------------------------------------------------------------------

model_dict = {
    "mistral7b": Mistral7b,
    "llama2": Llama2_7b,
    "llama3": Llama3_8b,
    "phi": PhiMini,
    "vicuna": Vicuna13b,
    "gemma": Gemma7b,
    "qwen": Qwen7b,
    "zephyr7b": Zephyr7bR2D2,
    "mistral7": Mistral7B_RR,
    "llama3rr": Llama3_8bRR,
}

# ---------------------------------------------------------------------
# Helper function to look up the semantic category from the CSV.
# ---------------------------------------------------------------------
def get_category(target_prompt: str, csv_df: pd.DataFrame) -> str:
    # First try an exact match in the "Behavior" column.
    matches = csv_df[csv_df['Behavior'] == target_prompt]
    if len(matches) > 0:
        return matches.iloc[0]['SemanticCategory']
    # Alternatively, try substring matching (case insensitive).
    for idx, row in csv_df.iterrows():
        if target_prompt.lower() in str(row['Behavior']).lower():
            return row['SemanticCategory']
    return "Unknown"

# ---------------------------------------------------------------------
# Main processing function: generation then evaluation.
# ---------------------------------------------------------------------
def process_evaluation(completions_json_path: str, csv_path: str, evaluation_path: str, device):
    # Load the completions JSON.
    with open(completions_json_path, "r") as f:
        completions_data = json.load(f)
    # Load the CSV file with semantic categories.
    csv_df = pd.read_csv(csv_path)

    
    # Infer model key from the JSON file name.
    fname = os.path.basename(completions_json_path).lower()
    model_key = None
    for key in model_dict.keys():
        if key in fname:
            model_key = key
            break
    if model_key is None:
        raise ValueError("Could not infer the model from the JSON file name.")
    
    print(f"Inferred model key '{model_key}' from filename '{fname}'.")
    
    # Instantiate the generation model.
    model = model_dict[model_key](device=device)
    model.tokenizer.pad_token = model.tokenizer.eos_token
    if model.tokenizer.pad_token is None:
        print("The model tokenizer does not have a pad token.")
        model.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        model.tokenizer.padding_side = "left"

    print(model.tokenizer.pad_token)
    # ---------------------------
    # Generation Phase: Create all model responses.
    # ---------------------------
    generation_data = []  # To store candidate responses per sample.
    for idx in sorted(completions_data.keys(), key=lambda x: int(x)):
        sample = completions_data[idx]
        o = sample.get("optimization", [])
        if not o:
            continue  # Skip samples with no optimization prompts
        
        # Expect o to be a list of lists; take the first element of each candidate.
        optim_prompts = [p[0] for p in o]
    
        # Prepare batch inputs.
        batch_inputs = []
        for p in optim_prompts:
            prompt = model._get_prompt(prompt=p)
            batch_inputs.append(prompt)
        
        # Tokenize and generate responses.
        inputs = model.tokenizer(batch_inputs, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.model.generate(**inputs, max_new_tokens=256, do_sample=False)
        generated_texts = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        responses = []
        for inp, gen in zip(batch_inputs, generated_texts):
            # Remove the input prompt if present.
            if inp in gen:
                #response_text = gen[len(inp):].strip()
                response_text = gen.split(inp)[-1].strip()
            else:
                response_text = gen.strip()
            
            if "assistant\n\n" in response_text:
                response_text = response_text.split("assistant\n\n")[-1].strip()

            if "\n<|assistant|>\n " in response_text:
                response_text = response_text.split("\n<|assistant|>\n")[-1].strip()
            
            if "\nmodel\n" in response_text:
                response_text = response_text.split("\nmodel\n")[-1].strip()
            
            if "\n<|assistant|>\n" in response_text:
                response_text = response_text.split("\n<|assistant|>\n")[-1].strip()
            
            responses.append(response_text)
            
        
        generation_data.append({
            "sample_idx": idx,
            "optim_prompts": optim_prompts,  # original candidate prompts
            "responses": responses,
        })
        print(f"Generated responses for sample {idx}.")
        with open(f'./completions_{model_key}_HB.json', "w") as out_f:
            json.dump(generation_data, out_f, indent=4)
    
    # After generation, free the model resources.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    
    # ---------------------------
    # Evaluation Phase: Evaluate all generated responses.
    # ---------------------------
    classifier = LLM(model='cais/HarmBench-Llama-2-13b-cls', tensor_parallel_size=1)
    classifier.llm_engine.tokenizer.truncation_side = "left"

    results = []
    for data in generation_data:
        optim_prompts = data["optim_prompts"]   # list of candidate prompts
        responses = data["responses"]           # corresponding candidate responses

        prompts = [optim_prompts[0]] * len(optim_prompts)
        print(prompts)
        # Evaluate the candidate responses using the harmbench judge.
        classifications = harmbench_judge_fn(classifier, prompts, responses)
        
        selected_response = ""

        asr = 0
        
        # If any candidate is classified as a jailbreak instance, select it.
        for i, (cand_response, cls) in enumerate(zip(responses, classifications)):
            if cls == 1:
                print("Jailbreak detected:", cand_response)
                selected_response = cand_response
                asr = 1
                final_adv = optim_prompts[i]
                break
        # If no candidate is flagged, default to the last candidate.
        if selected_response == "":
            asr = 0
            selected_response = responses[-1]
            final_adv = optim_prompts[-1]
        
        # Use the first optimization prompt as the base prompt.
        first_prompt = optim_prompts[0]
        category = get_category(first_prompt, csv_df)
        result_entry = {
            "category": category,
            "prompt": first_prompt,
            "final_adv": final_adv,
            "response": selected_response,
            "is_jailbreak_harmbench": asr,
        }
        results.append(result_entry)
        print(f"Processed sample {data['sample_idx']}: category='{category}', prompt='{first_prompt[:50]}...', response_len={len(selected_response)}")
    
        # Save the final evaluation results.
        with open(evaluation_path, "w") as out_f:
            json.dump(results, out_f, indent=4)
        print(f"Evaluation results saved at {evaluation_path}")

# ---------------------------------------------------------------------
# Main: argument parsing and entry point.
# ---------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate HF model completions against jailbreak behavior using harmbench judge"
    )
    parser.add_argument("--optimization_path", type=str, required=True,
                        help="Path to the JSON file containing the completions (dictionary with keys '0'...'n').")
    parser.add_argument("--csv_path", type=str, required=True,
                        help="Path to the CSV file containing semantic categories (must contain 'Behavior' and 'SemanticCategory' columns).")
    parser.add_argument("--evaluation_path", type=str, default="./evaluation_results.json",
                        help="Path to save the evaluation results as a JSON file.")
    
    args = parser.parse_args()
    
    process_evaluation(args.optimization_path, args.csv_path, args.evaluation_path, device='cuda')
